{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d3dd7f09",
   "metadata": {},
   "source": [
    "双语互译质量评估辅助工具：BLEU (bilingual evaluation understudy)。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a592fe40",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e61b866",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy parallel corpus: each entry is a [source_word, target_word] pair.\n",
    "seq_data = [\n",
    "    ['man', 'women'],\n",
    "    ['black', 'white'],\n",
    "    ['king', 'queen'],\n",
    "    ['girl', 'boy'],\n",
    "    ['up', 'down'],\n",
    "    ['high', 'low'],\n",
    "]\n",
    "# Character vocabulary: 'S' is prepended to decoder inputs, 'E' is appended\n",
    "# to targets, 'P' is the padding character; the rest are lowercase letters.\n",
    "char_arr = list('SEPabcdefghijklmnopqrstuvwxyz')\n",
    "num_dict = dict(zip(char_arr, range(len(char_arr))))\n",
    "\n",
    "# Network parameters\n",
    "n_step = 5  # words are right-padded with 'P' up to this many characters\n",
    "n_hidden = 128  # hidden size of the RNN layers\n",
    "n_class = len(num_dict)  # vocabulary size, i.e. one-hot width\n",
    "batch_size = len(seq_data)  # the whole corpus forms a single batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4edafbb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare data\n",
    "def make_batch(seq_data):\n",
    "    \"\"\"Convert [source, target] word pairs into one-hot training tensors.\n",
    "\n",
    "    Each word is right-padded with 'P' up to n_step characters. The decoder\n",
    "    input is the padded target prefixed with 'S'; the training target is the\n",
    "    padded target followed by 'E'.\n",
    "\n",
    "    Args:\n",
    "        seq_data: sequence of [source_word, target_word] pairs. It is only\n",
    "            read, never modified.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (input_batch, output_batch, target_batch):\n",
    "        input_batch is a float32 tensor of one-hot encoder inputs,\n",
    "        output_batch is a float32 tensor of one-hot decoder inputs and\n",
    "        target_batch is an int64 tensor of target character indices.\n",
    "    \"\"\"\n",
    "    input_batch, output_batch, target_batch = [], [], []\n",
    "    one_hot = np.eye(n_class)  # row i is the one-hot vector of class i\n",
    "\n",
    "    for seq in seq_data:\n",
    "        # Pad into fresh local strings instead of writing back into seq, so\n",
    "        # the caller's list is left untouched.\n",
    "        src = seq[0].ljust(n_step, 'P')\n",
    "        tgt = seq[1].ljust(n_step, 'P')\n",
    "\n",
    "        enc_ids = [num_dict[ch] for ch in src]\n",
    "        dec_ids = [num_dict[ch] for ch in 'S' + tgt]\n",
    "        tgt_ids = [num_dict[ch] for ch in tgt + 'E']\n",
    "\n",
    "        input_batch.append(one_hot[enc_ids])\n",
    "        output_batch.append(one_hot[dec_ids])\n",
    "        target_batch.append(tgt_ids)\n",
    "\n",
    "    # Stack into a single ndarray first: building a tensor directly from a\n",
    "    # list of ndarrays is slow and makes PyTorch emit a UserWarning.\n",
    "    return (\n",
    "        torch.tensor(np.array(input_batch), dtype=torch.float32),\n",
    "        torch.tensor(np.array(output_batch), dtype=torch.float32),\n",
    "        torch.tensor(target_batch, dtype=torch.long),\n",
    "    )\n",
    "\n",
    "\n",
    "input_batch, output_batch, target_batch = make_batch(seq_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7651595",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the network\n",
    "class Seq2Seq(nn.Module):\n",
    "    \"\"\"Character-level encoder-decoder built from two single-layer RNNs.\n",
    "\n",
    "    Key points:\n",
    "    1. The encoder and the decoder are nn.RNN modules with the same\n",
    "       configuration; a fully connected layer maps the decoder outputs to\n",
    "       per-class scores.\n",
    "    2. The core of seq2seq: the final hidden state produced by the encoder\n",
    "       is used as the initial hidden state of the decoder.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # input_size is the feature width of each time step (the one-hot\n",
    "        # size); hidden_size is the width of the hidden state.\n",
    "        # No dropout argument is passed: nn.RNN applies dropout only between\n",
    "        # stacked layers, so with the default single layer it had no effect\n",
    "        # other than raising a UserWarning at construction time.\n",
    "        self.enc = nn.RNN(input_size=n_class, hidden_size=n_hidden)\n",
    "        self.dec = nn.RNN(input_size=n_class, hidden_size=n_hidden)\n",
    "        self.fc = nn.Linear(n_hidden, n_class)\n",
    "\n",
    "    def forward(self, enc_input, enc_hidden, dec_input):\n",
    "        \"\"\"Run the encoder, then the decoder, and score every decoder step.\n",
    "\n",
    "        Args:\n",
    "            enc_input: batch-first encoder input, (batch, enc_len, n_class).\n",
    "            enc_hidden: initial encoder hidden state, (1, batch, n_hidden).\n",
    "            dec_input: batch-first decoder input, (batch, dec_len, n_class).\n",
    "\n",
    "        Returns:\n",
    "            Class scores of shape (dec_len, batch, n_class); the sequence\n",
    "            axis comes first.\n",
    "        \"\"\"\n",
    "        # nn.RNN expects (seq_len, batch, n_class), so swap the first two axes.\n",
    "        enc_input = enc_input.transpose(0, 1)\n",
    "        dec_input = dec_input.transpose(0, 1)\n",
    "        # Only the final encoder hidden state is kept; it seeds the decoder.\n",
    "        _, enc_states = self.enc(enc_input, enc_hidden)\n",
    "        outputs, _ = self.dec(dec_input, enc_states)\n",
    "        pred = self.fc(outputs)\n",
    "\n",
    "        return pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9810a232",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training\n",
    "# Fix the RNG seed so that weight initialisation, and therefore the whole\n",
    "# training run, is reproducible after Restart & Run All.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "model = Seq2Seq()\n",
    "loss_fun = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "n_epochs = 5000\n",
    "log_every = 1000  # print the loss once every this many epochs\n",
    "\n",
    "# Initial encoder hidden state, (1, batch_size, n_hidden). It is all zeros\n",
    "# and never modified, so it is built once outside the loop.\n",
    "hidden = torch.zeros(1, batch_size, n_hidden)\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    optimizer.zero_grad()\n",
    "    # The model returns (seq_len, batch, n_class); put the batch axis first\n",
    "    # so that iterating over pred yields one sample at a time.\n",
    "    pred = model(input_batch, hidden, output_batch).transpose(0, 1)\n",
    "    # Total loss = sum over samples of the per-sample cross entropy.\n",
    "    loss = sum(loss_fun(p, t) for p, t in zip(pred, target_batch))\n",
    "    if (epoch + 1) % log_every == 0:\n",
    "        print('Epoch: %d   Cost: %f' % (epoch + 1, loss.item()))\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cc1d4a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test\n",
    "def translate(word):\n",
    "    \"\"\"Run the trained model on a single word and return its prediction.\n",
    "\n",
    "    Args:\n",
    "        word: source word made of characters present in num_dict.\n",
    "\n",
    "    Returns:\n",
    "        The predicted characters up to (not including) the first 'E', with\n",
    "        padding characters 'P' removed. If the model predicts no 'E', the\n",
    "        whole predicted sequence is used.\n",
    "    \"\"\"\n",
    "    # The target side is only a placeholder made of padding characters; the\n",
    "    # target tensor returned by make_batch is not needed for inference.\n",
    "    input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]])\n",
    "    # hidden shape: (1, 1, n_hidden)\n",
    "    hidden = torch.zeros(1, 1, n_hidden)\n",
    "    # Inference only, so skip building the autograd graph.\n",
    "    with torch.no_grad():\n",
    "        # output shape: (decoder_len, 1, n_class)\n",
    "        output = model(input_batch, hidden, output_batch)\n",
    "    # Most likely class at every step, as a flat list of Python ints.\n",
    "    predict = output.argmax(dim=2).squeeze(1).tolist()\n",
    "    decoded = [char_arr[i] for i in predict]\n",
    "    # Cut at the first end marker. If the model never predicted 'E', keep\n",
    "    # everything instead of letting list.index raise ValueError.\n",
    "    if 'E' in decoded:\n",
    "        decoded = decoded[:decoded.index('E')]\n",
    "\n",
    "    return ''.join(decoded).replace('P', '')\n",
    "\n",
    "\n",
    "print('girl ->', translate('girl'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
